//
// Created by hjhu on 2024/7/2.
//

#ifndef TEST_WAV2LIP_CPP_FAN_H
#define TEST_WAV2LIP_CPP_FAN_H

#include <torch/torch.h>
#include "ConvBlock.h"
#include "BaseModule.h"

namespace coastal {

    // NOTE(review): these using-directives live at namespace scope in a header.
    // Every translation unit that includes this file therefore sees all names from
    // `torch` and `std` during lookup in (and through) namespace `coastal`. This can
    // cause ambiguities between same-named entities (e.g. `std::min` vs `torch::min`).
    // The declarations below already use fully qualified names. These directives are
    // left in place only because code elsewhere may depend on them; verify that
    // before removing.
    using namespace torch;
    using namespace std;

    /// libtorch module implementation derived from coastal::BaseModuleImpl.
    ///
    /// Holds a first-stage Conv2d + BatchNorm2d pair and three ConvBlock
    /// submodules. NOTE(review): the layer layout mirrors the "FAN" network of the
    /// face-alignment project. Confirm the exact channel sizes and the wiring in
    /// the constructor's definition, which is not part of this header.
    class FanImpl : public coastal::BaseModuleImpl {

    public:
        /// Constructs the module.
        /// @param num_modules Value stored in the `num_modules` member (default 1).
        ///        NOTE(review): presumably the number of stacked modules, as in
        ///        the upstream FAN; confirm against the constructor definition.
        /// `explicit` prevents accidental implicit conversion from an integer
        /// (e.g. `FanImpl f = 4;`). Direct construction, `std::make_shared<FanImpl>(n)`
        /// and libtorch's `ModuleHolder` forwarding are unaffected.
        explicit FanImpl(int64_t num_modules = 1);

        /// Runs the forward pass on `x` and returns the resulting tensor.
        /// NOTE(review): the expected input shape/dtype is not visible from this
        /// header; document it here once confirmed from the implementation.
        torch::Tensor forward(torch::Tensor x);

    private:
        // Mirrors the constructor's default argument so the member is never
        // left indeterminate.
        int64_t num_modules{1};

        // `nullptr` puts the holders into libtorch's empty state. Conv2dImpl and
        // BatchNorm2dImpl have no default constructor, so without this each holder
        // must be initialized in the constructor's member-initializer list. With it,
        // the constructor may also assign/register them in its body.
        torch::nn::Conv2d conv1{nullptr};
        torch::nn::BatchNorm2d bn1{nullptr};

        // ConvBlock submodules, held by shared_ptr (null until assigned by the
        // constructor).
        std::shared_ptr<coastal::ConvBlockImpl> conv2;
        std::shared_ptr<coastal::ConvBlockImpl> conv3;
        std::shared_ptr<coastal::ConvBlockImpl> conv4;
    };

} // coastal

#endif //TEST_WAV2LIP_CPP_FAN_H
